Dmoe integration #1210
base: main
Conversation
- Removed mp assertion for MoE
- Removed mlp_type checks in MoE code
- Added bf16 conversion to dmoe_gather
megatron/mpu/mappings.py (Outdated)
@@ -185,9 +180,102 @@ def _dmoe_gather(input_: torch.Tensor, tokens_per_expert: torch.Tensor):
    # Note: torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=gather_dim)

    # Bf16 convert
Removing this results in fp32 output.
This was resolved in the latest commit.
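For reference, here is a minimal sketch of the dtype-preserving behavior being discussed, assuming the gather path may upcast shards to fp32 along the way; the helper name and signature are illustrative, not the actual _dmoe_gather code.

```python
import torch

def gather_preserving_dtype(tensor_list, input_dtype, gather_dim=0):
    """Concatenate gathered shards and cast back to the caller's dtype.

    Hypothetical helper: if the gathered shards were upcast (e.g. to fp32)
    somewhere in the gather path, convert the result back so bf16
    activations stay bf16 downstream.
    """
    # torch.cat already creates a contiguous tensor.
    output = torch.cat(tensor_list, dim=gather_dim)
    if output.dtype != input_dtype:
        output = output.to(dtype=input_dtype)
    return output


# Example: fp32 shards gathered on behalf of a bf16 caller come back as bf16.
shards = [torch.randn(4, 8), torch.randn(6, 8)]
out = gather_preserving_dtype(shards, torch.bfloat16)
assert out.dtype == torch.bfloat16 and out.shape == (10, 8)
```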
Profiles before and after the merge:
@@ -67,35 +58,26 @@

   # regularization
   "gradient_clipping": 1.0,
-  "weight_decay": 0.1,
+  "weight_decay": 0.0,
There appear to be a lot of extraneous config changes. Any reason why?
@@ -1075,15 +1077,6 @@ def calculate_derived(self):
    # if we set pipe_parallel_size to 0, GPT2ModelPipe.to_sequential() is called, and we run training with
    # the sequential model without the PipelineModule wrapper to avoid the overhead it incurs
    self.update_value("is_pipe_parallel", self.pipe_parallel_size >= 1)
-   if self.moe_num_experts > 1:
Did we test these parallelism combinations?
Supersedes #1197
This PR adds dropless MoE support using the Grouped GEMM implementation in megablocks.
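To make the "dropless" part concrete, here is a rough pure-PyTorch sketch of the token path: every token is routed (there is no capacity factor, so nothing is dropped), tokens are permuted so each expert's tokens are contiguous, and the per-expert matmuls in the loop below are what megablocks' grouped GEMM fuses into a single kernel. All names and shapes are illustrative, not the PR's actual code.

```python
import torch

def dropless_moe_ffn(x, router_logits, w1, w2):
    """Illustrative dropless MoE forward pass (top-1 routing, no capacity factor).

    x:             (tokens, hidden)
    router_logits: (tokens, num_experts)
    w1:            (num_experts, hidden, ffn)
    w2:            (num_experts, ffn, hidden)
    """
    num_experts = router_logits.shape[-1]
    probs, expert_idx = router_logits.softmax(dim=-1).max(dim=-1)

    # Sort tokens by their assigned expert; no token is ever dropped.
    order = torch.argsort(expert_idx)
    x_sorted = x[order]
    tokens_per_expert = torch.bincount(expert_idx, minlength=num_experts)

    # Per-expert GEMMs over contiguous token blocks. This loop is what a
    # grouped GEMM kernel (megablocks / grouped_gemm) performs in one call.
    out_sorted = torch.empty_like(x_sorted)
    start = 0
    for e in range(num_experts):
        end = start + int(tokens_per_expert[e])
        h = torch.relu(x_sorted[start:end] @ w1[e])
        out_sorted[start:end] = h @ w2[e]
        start = end

    # Undo the permutation and scale by the router probability.
    out = torch.empty_like(out_sorted)
    out[order] = out_sorted
    return out * probs.unsqueeze(-1)


# Tiny smoke test with random weights.
T, H, F, E = 16, 32, 64, 4
x = torch.randn(T, H)
logits = torch.randn(T, E)
w1, w2 = torch.randn(E, H, F), torch.randn(E, F, H)
y = dropless_moe_ffn(x, logits, w1, w2)
assert y.shape == (T, H)
```

In the real path the loop becomes a single grouped GEMM call, which is why a tokens_per_expert tensor travels alongside the permuted activations (as in the _dmoe_gather signature shown above).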
Features
- Unlike the legacy DeepSpeed MoE implementation, which uses the data parallel groups for expert parallelism, this implementation uses the model parallel group to parallelize the experts, avoiding several problems with the legacy approach.
- Clarified arguments by separating MoE args into their own class.
- Sinkhorn routing is used by default during training and supports k >= 1; top-k routing is used for evaluation/inference (see the routing sketch after this list).
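A rough sketch of that routing split, assuming a plain Sinkhorn-Knopp normalization of the router scores during training and ordinary top-k of the softmax at evaluation time; the exact gate computation in the PR may differ.

```python
import torch

def sinkhorn_balance(logits: torch.Tensor, n_iters: int = 10) -> torch.Tensor:
    """Sinkhorn-Knopp normalization of router scores.

    Alternately normalizing expert (column) and token (row) sums pushes the
    assignment matrix toward roughly uniform expert load, which is why it is
    used for routing during training.
    """
    cost = torch.exp(logits.float())
    for _ in range(n_iters):
        cost = cost / (cost.sum(dim=0, keepdim=True) + 1e-8)  # balance expert load
        cost = cost / (cost.sum(dim=1, keepdim=True) + 1e-8)  # normalize per token
    return cost


def route(logits: torch.Tensor, k: int, training: bool):
    """Pick k experts per token: Sinkhorn-balanced scores choose the experts
    during training, plain softmax top-k at evaluation/inference."""
    if training:
        balanced = sinkhorn_balance(logits)
        _, expert_idx = torch.topk(balanced, k, dim=-1)
        # Gate values here come from the ordinary softmax; Sinkhorn only
        # decides *which* experts are selected (one common variant).
        gates = torch.gather(logits.softmax(dim=-1), -1, expert_idx)
    else:
        gates, expert_idx = torch.topk(logits.softmax(dim=-1), k, dim=-1)
    return gates, expert_idx


# Example: route 8 tokens to 2 of 4 experts.
gates, idx = route(torch.randn(8, 4), k=2, training=True)
assert gates.shape == idx.shape == (8, 2)
```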
Testing
Tested pipeline parallel sizes [3, 2, 1] and model parallel sizes [1, 2, 4, 8] on Ampere GPUs.
Notes
Added megablocks and grouped_gemm to the dependencies. It might be desirable to pull some of the kernels in directly, as NVIDIA megatron-core does.